{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import argparse\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision.utils as utils\n",
    "import pytorch_ssim\n",
    "import  time \n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from torch.nn.modules.loss import _Loss \n",
    "from Trans_unet_Gan import *\n",
    "#from dataset import prepare_data, Dataset\n",
    "from utils import *\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split(img):\n",
    "    output=[]\n",
    "    output.append(F.interpolate(img, scale_factor=0.125))\n",
    "    output.append(F.interpolate(img, scale_factor=0.25))\n",
    "    output.append(F.interpolate(img, scale_factor=0.5))\n",
    "    output.append(img)\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "dtype = 'float32'\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n",
    "torch.set_default_tensor_type(torch.FloatTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_x=[]\n",
    "path='./final/train/input/'#要改\n",
    "path_list = os.listdir(path)\n",
    "path_list.sort(key=lambda x:int(x.split('.')[0]))\n",
    "for item in path_list:\n",
    "    impath=path+item\n",
    "    #print(\"开始处理\"+impath)\n",
    "    imgx= cv2.imread(path+item)\n",
    "    imgx = cv2.cvtColor(imgx, cv2.COLOR_BGR2RGB)\n",
    "    imgx=cv2.resize(imgx,(256,256))\n",
    "    #imgx=imgx/255.0\n",
    "    training_x.append(imgx)\n",
    "    \n",
    "    \n",
    "    \n",
    "\n",
    "X_train = []\n",
    "\n",
    "for features in training_x:\n",
    "    X_train.append(features)\n",
    "    \n",
    "\n",
    "    \n",
    "#X_train = np.array(X_train).reshape(-1,3,256,256)\n",
    "X_train = np.array(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train=X_train.astype(dtype)\n",
    "X_train= torch.from_numpy(X_train)\n",
    "X_train=X_train.permute(0,3,1,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([9820, 3, 256, 256])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train=X_train/255.0\n",
    "X_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_y=[]\n",
    "path='./final/train/GT/'#要改\n",
    "path_list = os.listdir(path)\n",
    "path_list.sort(key=lambda x:int(x.split('.')[0]))\n",
    "for item in path_list:\n",
    "    impath=path+item\n",
    "    #print(\"开始处理\"+impath)\n",
    "    imgx= cv2.imread(path+item)\n",
    "    imgx = cv2.cvtColor(imgx, cv2.COLOR_BGR2RGB)\n",
    "    imgx=cv2.resize(imgx,(256,256))\n",
    "    #imgx=imgx/255.0\n",
    "    training_y.append(imgx)\n",
    "    \n",
    "    \n",
    "    \n",
    "\n",
    "y_train = []\n",
    "\n",
    "for features in training_y:\n",
    "    y_train.append(features)\n",
    "    \n",
    "\n",
    "    \n",
    "y_train = np.array(y_train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train=y_train.astype(dtype)\n",
    "y_train= torch.from_numpy(y_train)\n",
    "y_train=y_train.permute(0,3,1,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([9820, 3, 256, 256])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train=y_train/255.0\n",
    "y_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_x=[]\n",
    "path='./final/test/input/'#要改\n",
    "path_list = os.listdir(path)\n",
    "path_list.sort(key=lambda x:int(x.split('.')[0]))\n",
    "for item in path_list:\n",
    "    impath=path+item\n",
    "    #print(\"开始处理\"+impath)\n",
    "    imgx= cv2.imread(path+item)\n",
    "    imgx = cv2.cvtColor(imgx, cv2.COLOR_BGR2RGB)\n",
    "    imgx=cv2.resize(imgx,(256,256))\n",
    "    #imgy=imgy/255.0\n",
    "    test_x.append(imgx)\n",
    "    \n",
    "    \n",
    "    \n",
    "\n",
    "x_test = []\n",
    "\n",
    "for features in test_x:\n",
    "    x_test.append(features)\n",
    "    \n",
    "\n",
    "    \n",
    "x_test = np.array(x_test)\n",
    "#X_train = np.array(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_test=x_test.astype(dtype)\n",
    "x_test= torch.from_numpy(x_test)\n",
    "x_test=x_test.permute(0,3,1,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_test=x_test/255.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([500, 3, 256, 256])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_Y=[]\n",
    "path='./final/test/GT/'#要改\n",
    "path_list = os.listdir(path)\n",
    "path_list.sort(key=lambda x:int(x.split('.')[0]))\n",
    "for item in path_list:\n",
    "    impath=path+item\n",
    "    #print(\"开始处理\"+impath)\n",
    "    imgx= cv2.imread(path+item)\n",
    "    imgx = cv2.cvtColor(imgx, cv2.COLOR_BGR2RGB)\n",
    "    imgx=cv2.resize(imgx,(256,256))\n",
    "    #imgy=imgy/255.0\n",
    "    test_Y.append(imgx)\n",
    "    \n",
    "    \n",
    "    \n",
    "\n",
    "Y_test = []\n",
    "\n",
    "for features in test_Y:\n",
    "    Y_test.append(features)\n",
    "    \n",
    "\n",
    "    \n",
    "Y_test = np.array(Y_test)\n",
    "#X_train = np.array(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_test=Y_test.astype(dtype)\n",
    "Y_test= torch.from_numpy(Y_test)\n",
    "Y_test=Y_test.permute(0,3,1,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_test=Y_test/255.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([500, 3, 256, 256])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.utils.data as dataf\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "\n",
    "dataset = dataf.TensorDataset(X_train,y_train)\n",
    "loader = dataf.DataLoader(dataset, batch_size=6, shuffle=True,num_workers=4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bit/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "MSELoss()"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "criterion_GAN=nn.MSELoss(size_average=False)\n",
    "criterion_pixelwise=nn.MSELoss(size_average=False)\n",
    "MSE = nn.MSELoss(size_average=False)\n",
    "SSIM = pytorch_ssim.SSIM()\n",
    "L_vgg = VGG19_PercepLoss()\n",
    "criterion_pixelwise.cuda()\n",
    "L_vgg = L_vgg.cuda()\n",
    "MSE.cuda()\n",
    "SSIM.cuda()\n",
    "criterion_GAN.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss weight of L1 pixel-wise loss between translated image and real image\n",
    "lambda_pixel =0.1\n",
    "lambda_con = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate output of image discriminator (PatchGAN)\n",
    "patch = (1, 256 // 2 ** 5, 256// 2 ** 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize generator and discriminator\n",
    "generator = Generator().cuda()\n",
    "discriminator = Discriminator().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "无保存模型，将从头开始训练！\n"
     ]
    }
   ],
   "source": [
    "# 如果有保存的模型，则加载模型，并在其基础上继续训练\n",
    "use_pretrain=False\n",
    "if use_pretrain:\n",
    "    # Load pretrained models\n",
    "    start_epoch=0\n",
    "    generator.load_state_dict(torch.load(\"saved_models/G/generator_%d.pth\" % (start_epoch)))\n",
    "    discriminator.load_state_dict(torch.load(\"saved_models/D/discriminator_%d.pth\" % (start_epoch)))\n",
    "    print('加载 epoch {} 成功！'.format(start_epoch))\n",
    "else:\n",
    "    start_epoch = 0\n",
    "    print('无保存模型，将从头开始训练！')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_images(batches_done):\n",
    "    \"\"\"Saves a generated sample from the validation set\"\"\"\n",
    "    generator.eval()\n",
    "    i=random.randrange(1,499)\n",
    "    real_A = Variable(x_test[i,:,:,:]).cuda()\n",
    "    real_B = Variable(Y_test[i,:,:,:]).cuda()\n",
    "    real_A=real_A.unsqueeze(0)\n",
    "    real_B=real_B.unsqueeze(0)\n",
    "    fake_B = generator(real_A)\n",
    "    #print(fake_B.shape)\n",
    "    imgx=fake_B[3].data\n",
    "    imgy=real_B.data\n",
    "    x=imgx[:,:,:,:]\n",
    "    y=imgy[:,:,:,:]\n",
    "    img_sample = torch.cat((x,y), -2)\n",
    "    save_image(img_sample, \"images/%s/%s.png\" % ('results', batches_done), nrow=5, normalize=True)#要改"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import lr_scheduler\n",
    "LR=0.0005\n",
    "\n",
    "# Optimizers\n",
    "optimizer_G = torch.optim.Adam(generator.parameters(), lr=LR,  betas=(0.5, 0.999))\n",
    "optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=LR,  betas=(0.5, 0.999))\n",
    "scheduler_G=lr_scheduler.StepLR(optimizer_G,step_size=40,gamma=0.8)\n",
    "scheduler_D=lr_scheduler.StepLR(optimizer_D,step_size=40,gamma=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "cuda = True if torch.cuda.is_available() else False\n",
    "Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bit/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/functional.py:3502: UserWarning: The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details. \n",
      "  warnings.warn(\n",
      "/home/bit/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/functional.py:3454: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.\n",
      "  warnings.warn(\n",
      "/home/bit/plt/Ushape_net_ourdata/utils.py:51: UserWarning: DEPRECATED: skimage.measure.compare_psnr has been moved to skimage.metrics.peak_signal_noise_ratio. It will be removed from skimage.measure in version 0.18.\n",
      "  PSNR += compare_psnr(Iclean[i,:,:,:], Img[i,:,:,:], data_range=data_range)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 4/500] [Batch 452/1637][PSNR: 19.826566] [SSIM: 0.665947] [D loss: 3.313898] [G loss: 1057.849121], [pixel: 723.990576],[VGG_loss: 82.768810], [adv: 317.684509] ETA: 5 days, 11:36:55.675378830872"
     ]
    },
    {
     "ename": "IndexError",
     "evalue": "index 513 is out of bounds for dimension 0 with size 500",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-28-ffbbaeb2c034>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m    138\u001b[0m         \u001b[0;31m# If at sample interval save image\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    139\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mbatches_done\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0msample_interval\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 140\u001b[0;31m             \u001b[0msample_images\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatches_done\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    141\u001b[0m             \u001b[0mcsv_writer1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwriterow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpsnr_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    142\u001b[0m             \u001b[0mcsv_writer2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwriterow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mssim_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-25-581abb5353ff>\u001b[0m in \u001b[0;36msample_images\u001b[0;34m(batches_done)\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0mgenerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0mi\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m514\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m     \u001b[0mreal_A\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_test\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m     \u001b[0mreal_B\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mY_test\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0mreal_A\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreal_A\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mIndexError\u001b[0m: index 513 is out of bounds for dimension 0 with size 500"
     ]
    }
   ],
   "source": [
    "# ----------\n",
    "#  Training\n",
    "# ----------\n",
    "import time as time\n",
    "import datetime\n",
    "import sys\n",
    "from torchvision.utils import save_image\n",
    "import csv\n",
    "import random\n",
    "\n",
    "\n",
    "f1 = open('psnr.csv','w',encoding='utf-8')#要改\n",
    "csv_writer1 = csv.writer(f1)\n",
    "f2 = open('SSIM.csv','w',encoding='utf-8')#要改\n",
    "csv_writer2 = csv.writer(f2)\n",
    "\n",
    "checkpoint_interval=5\n",
    "epochs=start_epoch\n",
    "n_epochs=500\n",
    "sample_interval=1000\n",
    "\n",
    "# ingnored when opt.mode=='S'\n",
    "psnr_list = [] \n",
    "prev_time = time.time()\n",
    "\n",
    "for epoch in range(epochs,n_epochs):\n",
    "    for i, batch in enumerate(loader):\n",
    "\n",
    "        # Model inputs\n",
    "        real_A = Variable(batch[0]).cuda() \n",
    "        real_B = Variable(batch[1]).cuda()\n",
    "        #real_A1=Tensor(np.array(split(real_A)).astype(dtype))\n",
    "        #real_B1=Tensor(np.array(split(real_B)).astype(dtype))\n",
    "        real_A1=split(real_A)\n",
    "        real_B1=split(real_B)\n",
    "        #print(len(real_A1))\n",
    "        #print(len(real_B1))\n",
    "        #real_A1= torch.tensor([item.cpu().detach().numpy() for item in real_A1]).cuda() \n",
    "        #real_B1= torch.tensor([item.cpu().detach().numpy() for item in real_B1]).cuda() \n",
    "\n",
    "\n",
    "        # Adversarial ground truths\n",
    "        valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)#全1\n",
    "        fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)#全0\n",
    "\n",
    "        # ------------------\n",
    "        #  Train Generators\n",
    "        # ------------------\n",
    "\n",
    "        optimizer_G.zero_grad()\n",
    "\n",
    "        # GAN loss\n",
    "        fake_B = generator(real_A)\n",
    "        #fake_B1 = list(map(lambda x: x.detach(), fake_B))\n",
    "        pred_fake = discriminator(fake_B, real_A1)\n",
    "        loss_GAN = criterion_GAN(pred_fake, valid)\n",
    "        # Pixel-wise loss\n",
    "        loss_pixel = (criterion_pixelwise(fake_B[0], real_B1[0])+\\\n",
    "                      criterion_pixelwise(fake_B[1], real_B1[1])+\\\n",
    "                      criterion_pixelwise(fake_B[2], real_B1[2])+\\\n",
    "                      criterion_pixelwise(fake_B[3], real_B1[3]))/4.0\n",
    "        loss_ssim=   -(SSIM(fake_B[0], real_B1[0])+\\\n",
    "                      SSIM(fake_B[1], real_B1[1])+\\\n",
    "                      SSIM(fake_B[2], real_B1[2])+\\\n",
    "                      SSIM(fake_B[3], real_B1[3]))/4.0\n",
    "        ssim_value = - loss_ssim.item()\n",
    "        loss_con = (L_vgg(fake_B[0], real_B1[0])+\\\n",
    "                    L_vgg(fake_B[1], real_B1[1])+\\\n",
    "                    L_vgg(fake_B[2], real_B1[2])+\\\n",
    "                    L_vgg(fake_B[3], real_B1[3]))/4.0\n",
    "        \n",
    "        \n",
    "\n",
    "        # Total loss\n",
    "        loss_G = loss_GAN + lambda_pixel * loss_pixel+100*loss_ssim+lambda_con*loss_con\n",
    "\n",
    "        loss_G.backward(retain_graph=True)\n",
    "\n",
    "        optimizer_G.step()\n",
    "\n",
    "        # ---------------------\n",
    "        #  Train Discriminator\n",
    "        # ---------------------\n",
    "\n",
    "        optimizer_D.zero_grad()\n",
    "\n",
    "        # Real loss\n",
    "        pred_real = discriminator(real_B1, real_A1)\n",
    "        loss_real = criterion_GAN(pred_real, valid)\n",
    "\n",
    "        # Fake loss\n",
    "        fake_B[0]=fake_B[0].detach()\n",
    "        fake_B[1]=fake_B[1].detach()\n",
    "        fake_B[2]=fake_B[2].detach()\n",
    "        fake_B[3]=fake_B[3].detach()\n",
    "        pred_fake1 = discriminator(fake_B, real_A1)\n",
    "        loss_fake = criterion_GAN(pred_fake1, fake)\n",
    "\n",
    "        # Total loss\n",
    "        loss_D = 0.5 * (loss_real + loss_fake)\n",
    "        #loss_D=loss_real\n",
    "\n",
    "        loss_D.backward(retain_graph=True)\n",
    "        optimizer_D.step()\n",
    "\n",
    "        # --------------\n",
    "        #  Log Progress\n",
    "        # --------------\n",
    "\n",
    "        # Determine approximate time left\n",
    "        batches_done = epoch * len(loader) + i\n",
    "        batches_left = n_epochs * len(loader) - batches_done\n",
    "        out_train= torch.clamp(fake_B[3], 0., 1.) \n",
    "        psnr_train = batch_PSNR(out_train,real_B, 1.)\n",
    "        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))\n",
    "        prev_time = time.time()\n",
    "\n",
    "        # Print log\n",
    "        sys.stdout.write(\n",
    "            \"\\r[Epoch %d/%d] [Batch %d/%d][PSNR: %f] [SSIM: %f] [D loss: %f] [G loss: %f], [pixel: %f],[VGG_loss: %f], [adv: %f] ETA: %s\"\n",
    "            % (\n",
    "                epoch,\n",
    "                n_epochs,\n",
    "                i,\n",
    "                len(loader),\n",
    "                psnr_train,\n",
    "                ssim_value,\n",
    "                loss_D.item(),\n",
    "                loss_G.item(),\n",
    "                0.1*loss_pixel.item(),\n",
    "                100*loss_con.item(),\n",
    "                loss_GAN.item(),\n",
    "                time_left,\n",
    "            )\n",
    "        )\n",
    "        \n",
    "\n",
    "        # If at sample interval save image\n",
    "        if batches_done % sample_interval == 0:\n",
    "            sample_images(batches_done)\n",
    "            csv_writer1.writerow([str(psnr_train)])\n",
    "            csv_writer2.writerow([str(ssim_value)])\n",
    "            \n",
    "            \n",
    "    scheduler_G.step()\n",
    "    scheduler_D.step()\n",
    "    if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:\n",
    "        # Save model checkpoints\n",
    "        torch.save(generator.state_dict(), \"saved_models/G/generator_%d.pth\" % (epoch))\n",
    "        torch.save(discriminator.state_dict(), \"saved_models/D/discriminator_%d.pth\" % (epoch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
